import os
import glob
import numpy as np
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter import ttk

import mne
import matplotlib.pyplot as plt


# ============================================================
# DEFAULT OUTLIER SETTINGS (edit in GUI)
# ============================================================

# Initial values shown in the GUI threshold fields. For the two amplitude
# rules a blank GUI field is parsed as None, which disables that rule.
DEFAULT_PTP_THRESH_UV = 150.0        # peak-to-peak threshold (µV); blank disables
DEFAULT_MAXABS_THRESH_UV = 100.0     # max abs threshold (µV); blank disables

DEFAULT_RMS_ROBUST_Z_THRESH = 3.5    # robust z threshold on epoch RMS

# Per-channel RMS robust-z rule: initial checkbox state and its threshold.
DEFAULT_USE_CH_RMS_ROBUST_Z = True
DEFAULT_CH_RMS_ROBUST_Z_THRESH = 3.5

# Glob pattern used to find input epoch files inside the selected input folder.
# NOTE: safe_base_from_fif() strips the literal "_cleaned-epo.fif" suffix, so
# keep that literal in sync if this pattern changes.
EPOCH_FIF_GLOB = "*_cleaned-epo.fif"

# Appended to the participant base name to form the output epochs file name.
OUT_SUFFIX = "_manualOutliersRemoved-epo.fif"


# ============================================================
# ROBUST Z
# ============================================================
def robust_z(x: np.ndarray):
    """
    Return robust z-scores of ``x`` based on the median and the MAD.

    The score is ``0.6745 * (x - median) / MAD``, with NaNs ignored when the
    median and MAD are estimated (NaN entries stay NaN in the result).

    Fallbacks:
      * MAD is zero or not finite -> ordinary z-score ``(x - mean) / std``
        (NaN-aware mean/std).
      * that std is also zero or not finite -> all zeros.
      * ``x`` is empty or holds no finite value at all -> all zeros, returned
        early so NumPy does not emit "empty slice" / "All-NaN slice" warnings.

    Parameters
    ----------
    x : array-like
        Values to standardise; converted to a float array.

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)

    # Nothing usable to estimate a centre/spread from. The statistics below
    # would only produce NaN (plus RuntimeWarnings) and end up returning zeros.
    if not np.isfinite(x).any():
        return np.zeros_like(x)

    med = np.nanmedian(x)
    mad = np.nanmedian(np.abs(x - med))
    if mad > 0 and np.isfinite(mad):
        # 0.6745 rescales the MAD so scores are comparable to ordinary z-scores.
        return 0.6745 * (x - med) / mad

    # Degenerate MAD (e.g. more than half the values identical): use mean/std.
    mu = np.nanmean(x)
    sd = np.nanstd(x)
    if sd == 0 or not np.isfinite(sd):
        return np.zeros_like(x)
    return (x - mu) / sd


# ============================================================
# OUTLIER DETECTION (FLAGS ONLY; NO DROPPING)
# ============================================================
def flag_outlier_epochs(data_v: np.ndarray, ch_names,
                        ptp_thresh_uv,
                        maxabs_thresh_uv,
                        rms_robust_z_thresh,
                        use_ch_rms_robust_z,
                        ch_rms_robust_z_thresh):
    """
    Flag epochs whose amplitude statistics look like outliers (nothing is dropped).

    Parameters
    ----------
    data_v : np.ndarray, shape (n_epochs, n_ch, n_times)
        Epoch data in VOLTS; multiplied by 1e6 internally to work in µV.
    ch_names : sequence of str
        Channel labels, indexed like axis 1 of ``data_v``.
    ptp_thresh_uv : float | None
        Flag an epoch if any channel's peak-to-peak amplitude exceeds this
        value (µV). ``None`` disables the rule.
    maxabs_thresh_uv : float | None
        Flag an epoch if any channel's maximum absolute amplitude exceeds this
        value (µV). ``None`` disables the rule.
    rms_robust_z_thresh : float | None
        Flag an epoch if the absolute robust z-score of its whole-epoch RMS
        (across epochs) exceeds this value. ``None`` disables the rule.
    use_ch_rms_robust_z : bool
        Enable the per-channel RMS robust-z rule.
    ch_rms_robust_z_thresh : float
        Threshold for the per-channel rule; only read when that rule is enabled.

    Returns
    -------
    flag_mask : np.ndarray of bool, shape (n_epochs,)
        True where at least one rule fired.
    per_epoch : dict
        Arrays of length n_epochs: ``epoch_rms_uv``, ``epoch_rms_robustz``,
        ``ptp_max_uv`` and ``maxabs_max_uv`` (the last two are maxima across
        channels).
    reasons : list[list[str]]
        For each epoch, one string per rule that fired. Per-channel rules
        list at most the first 8 offending channel names.
    """
    n_epochs, n_ch, _ = data_v.shape
    data_uv = data_v * 1e6

    ptp_uv = data_uv.max(axis=2) - data_uv.min(axis=2)      # (n_epochs, n_ch)
    maxabs_uv = np.max(np.abs(data_uv), axis=2)             # (n_epochs, n_ch)

    # Square the (potentially large) array once and reuse it for both RMS
    # summaries instead of recomputing data_uv ** 2 for each of them.
    sq_uv = data_uv ** 2
    epoch_rms_uv = np.sqrt(np.mean(sq_uv, axis=(1, 2)))     # (n_epochs,)
    # Per-channel RMS is only needed by the per-channel robust-z rule.
    rms_uv = np.sqrt(np.mean(sq_uv, axis=2)) if use_ch_rms_robust_z else None
    del sq_uv

    epoch_rz = robust_z(epoch_rms_uv)

    flag_mask = np.zeros(n_epochs, dtype=bool)
    reasons = [[] for _ in range(n_epochs)]

    def _flag_channels(exceeds, label):
        # exceeds: (n_epochs, n_ch) bool. Flags every epoch with at least one
        # True entry and records which channels (first 8 only) triggered it.
        for ei in np.where(np.any(exceeds, axis=1))[0]:
            flag_mask[ei] = True
            shown = [ch_names[i] for i in np.where(exceeds[ei])[0][:8]]
            reasons[ei].append(f"{label} ({', '.join(shown)})")

    # PTP rule
    if ptp_thresh_uv is not None:
        _flag_channels(ptp_uv > float(ptp_thresh_uv), f"PTP>{ptp_thresh_uv}uV")

    # MAXABS rule
    if maxabs_thresh_uv is not None:
        _flag_channels(maxabs_uv > float(maxabs_thresh_uv), f"MAXABS>{maxabs_thresh_uv}uV")

    # Epoch RMS robust z
    if rms_robust_z_thresh is not None:
        for ei in np.where(np.abs(epoch_rz) > float(rms_robust_z_thresh))[0]:
            flag_mask[ei] = True
            reasons[ei].append(f"EpochRMS_rZ>{rms_robust_z_thresh} (z={epoch_rz[ei]:.2f})")

    # Channel RMS robust z (each channel standardised across epochs)
    if use_ch_rms_robust_z:
        ch_rz = np.zeros_like(rms_uv)
        for c in range(n_ch):
            ch_rz[:, c] = robust_z(rms_uv[:, c])
        _flag_channels(np.abs(ch_rz) > float(ch_rms_robust_z_thresh),
                       f"ChRMS_rZ>{ch_rms_robust_z_thresh}")

    per_epoch = {
        "epoch_rms_uv": epoch_rms_uv,
        "epoch_rms_robustz": epoch_rz,
        # keep these summaries too (max across channels), helpful for review
        "ptp_max_uv": ptp_uv.max(axis=1),
        "maxabs_max_uv": maxabs_uv.max(axis=1),
    }
    return flag_mask, per_epoch, reasons


# ============================================================
# MANUAL REVIEW DIALOG
# ============================================================
class EpochReviewDialog(tk.Toplevel):
    """
    Modal dialog to review one participant's flagged epochs one-by-one.

    For each flagged epoch the dialog shows the flag reasons, summary metrics
    and an MNE plot of that epoch; the user chooses keep or drop.

    The constructor blocks (``wait_window``) until the review ends. Afterwards
    ``self.decisions`` maps epoch index -> "drop" or "keep". Epochs that were
    not reached before the dialog was stopped/closed are absent from it.

    Buttons:
      Keep, Drop, Keep all remaining, Drop all remaining, Stop participant
    """
    def __init__(self, parent, epochs: mne.Epochs, flagged_indices, reasons, metrics, participant_name):
        """
        parent: Tk widget that owns this dialog.
        epochs: the participant's epochs; ``epochs[idx]`` is plotted for each flagged idx.
        flagged_indices: iterable of epoch indices to review, in review order.
        reasons: indexable by epoch index, giving a list of reason strings.
        metrics: dict with keys 'ptp_max_uv', 'maxabs_max_uv', 'epoch_rms_uv',
                 'epoch_rms_robustz', each indexable by epoch index.
        participant_name: label used in the window and plot titles.
        """
        super().__init__(parent)
        self.title(f"Manual Epoch Review - {participant_name}")
        self.resizable(False, False)

        self.epochs = epochs
        self.flagged = list(flagged_indices)
        self.reasons = reasons
        self.metrics = metrics
        self.participant = participant_name

        self.decisions = {}  # epoch_idx -> "drop" or "keep"
        self._mode = None    # "drop_all" or "keep_all" or None
        self._stop = False
        self._closed = False  # set once the window has been torn down

        self._i = 0          # position in self.flagged of the epoch on display

        self.info_var = tk.StringVar(value="")
        self.reason_var = tk.StringVar(value="")
        self.metric_var = tk.StringVar(value="")

        container = ttk.Frame(self, padding=10)
        container.pack(fill="both", expand=True)

        ttk.Label(container, textvariable=self.info_var, font=("Segoe UI", 10, "bold")).pack(anchor="w")
        ttk.Label(container, textvariable=self.reason_var, wraplength=650).pack(anchor="w", pady=(6, 0))
        ttk.Label(container, textvariable=self.metric_var, wraplength=650).pack(anchor="w", pady=(6, 0))

        btns = ttk.Frame(container)
        btns.pack(fill="x", pady=(10, 0))

        ttk.Button(btns, text="Keep", command=self.keep).pack(side="left")
        ttk.Button(btns, text="Drop", command=self.drop).pack(side="left", padx=6)
        ttk.Button(btns, text="Keep all remaining", command=self.keep_all).pack(side="left", padx=20)
        ttk.Button(btns, text="Drop all remaining", command=self.drop_all).pack(side="left", padx=6)
        ttk.Button(btns, text="Stop participant", command=self.stop).pack(side="right")

        # Closing the window counts as "Stop participant".
        self.protocol("WM_DELETE_WINDOW", self.stop)

        self.transient(parent)
        self.grab_set()

        # Kick off first epoch. If that fails (e.g. plotting error), tear the
        # window down before propagating so no modal window is left behind
        # holding the input grab.
        try:
            self._show_current()
        except Exception:
            self._finish()
            raise

        # Nothing to wait for if the review already ended (e.g. no flagged epochs).
        if not self._closed:
            self.wait_window()

    def _show_current(self):
        """Display the epoch at the current position, or finish if none is left."""
        if self._closed:
            return

        if self._stop or self._i >= len(self.flagged):
            self._finish()
            return

        idx = self.flagged[self._i]

        # Update labels
        self.info_var.set(f"Epoch {idx}   ({self._i+1}/{len(self.flagged)} flagged)")
        rs = "; ".join(self.reasons[idx]) if self.reasons[idx] else "(no reasons listed)"
        self.reason_var.set(f"Reasons: {rs}")

        m = self.metrics
        self.metric_var.set(
            f"Metrics: PTPmax={m['ptp_max_uv'][idx]:.1f} µV, "
            f"MAXABSmax={m['maxabs_max_uv'][idx]:.1f} µV, "
            f"EpochRMS={m['epoch_rms_uv'][idx]:.1f} µV, "
            f"EpochRZ={m['epoch_rms_robustz'][idx]:.2f}"
        )

        # Plot epoch in a new window (non-blocking)
        self._plot_epoch(idx)

    def _plot_epoch(self, idx):
        """Plot epoch ``idx`` in its own figure window, replacing any earlier figure."""
        # close prior figures to reduce clutter
        plt.close("all")

        # epochs[idx] returns an Epochs object containing one epoch
        ep = self.epochs[idx]

        # Use MNE's built-in epoch plotting
        # show=False prevents blocking; then matplotlib draws it
        fig = ep.plot(
            n_channels=min(20, len(self.epochs.ch_names)),
            scalings="auto",
            title=f"{self.participant} - Epoch {idx}",
            show=False
        )
        # In some MNE versions, ep.plot returns a Figure or list of Figures
        if isinstance(fig, list):
            for f in fig:
                f.show()
        else:
            fig.show()

        plt.show(block=False)

    def _decide(self, decision):
        """Record ``decision`` for the epoch on display and advance to the next one."""
        # Ignore stray clicks once the review is over or past the last epoch.
        if self._closed or self._i >= len(self.flagged):
            return
        self.decisions[self.flagged[self._i]] = decision
        self._i += 1
        # Advance synchronously: a second queued click can then only ever apply
        # to an epoch that has actually been displayed.
        self._show_current()

    def _decide_remaining(self, decision):
        """Apply ``decision`` to the displayed epoch and all later ones, then close."""
        if self._closed:
            return
        self._mode = f"{decision}_all"
        for idx in self.flagged[self._i:]:
            self.decisions[idx] = decision
        self._i = len(self.flagged)
        self._finish()

    def keep(self):
        """Keep the epoch on display."""
        self._decide("keep")

    def drop(self):
        """Drop the epoch on display."""
        self._decide("drop")

    def keep_all(self):
        """Keep the epoch on display and every remaining flagged epoch."""
        self._decide_remaining("keep")

    def drop_all(self):
        """Drop the epoch on display and every remaining flagged epoch."""
        self._decide_remaining("drop")

    def stop(self):
        """End the review now; epochs not yet decided stay out of ``decisions``."""
        self._stop = True
        self._finish()

    def _finish(self):
        """Close the review figure, release the grab and destroy the window (idempotent)."""
        if self._closed:
            return
        self._closed = True
        # Do not leave the last epoch plot open after the dialog is gone.
        plt.close("all")
        self.grab_release()
        self.destroy()


# ============================================================
# PROCESS ONE PARTICIPANT WITH MANUAL REVIEW
# ============================================================
def safe_base_from_fif(path: str) -> str:
    """
    Derive the participant base name from an epochs file path.

    The directory part is discarded, then the first matching suffix of
    "_cleaned-epo.fif" or ".fif" is removed. A file name with neither
    suffix is returned unchanged.
    """
    name = os.path.basename(path)
    # Most specific suffix first so "x_cleaned-epo.fif" becomes "x".
    for suffix in ("_cleaned-epo.fif", ".fif"):
        if name.endswith(suffix):
            return name[:len(name) - len(suffix)]
    return name


def process_one_participant_manual(parent, fpath: str, out_dir: str, settings: dict):
    """
    Flag, manually review and save one participant's epochs.

    Steps: read the epochs file, flag candidate outlier epochs, let the user
    keep/drop each flagged epoch in a modal dialog (skipped when nothing is
    flagged), then write the resulting epochs and a per-epoch decision CSV
    into ``out_dir``. Flagged epochs the user did not decide on are kept.

    Parameters
    ----------
    parent : Tk widget used as the parent of the review dialog.
    fpath : path of the input ``-epo.fif`` file.
    out_dir : output folder (created if missing).
    settings : dict with keys 'ptp_thresh_uv', 'maxabs_thresh_uv',
        'rms_rz_thresh', 'use_ch_rms_rz' and 'ch_rz_thresh'.

    Returns
    -------
    dict
        Summary with keys 'participant', 'n_epochs_orig', 'n_flagged',
        'n_dropped_manual', 'out_fif' and 'decisions_csv'.
    """
    subj = safe_base_from_fif(fpath)
    print(f"\nProcessing: {subj}")

    # Create the output folder first: an unwritable location should fail
    # before the user spends time on a manual review, not after it.
    os.makedirs(out_dir, exist_ok=True)
    out_fif = os.path.join(out_dir, f"{subj}{OUT_SUFFIX}")
    out_dec = os.path.join(out_dir, f"{subj}_manual_outlier_decisions.csv")

    epochs = mne.read_epochs(fpath, preload=True, verbose="ERROR")
    data_v = epochs.get_data()
    ch_names = epochs.ch_names

    flag_mask, metrics, reasons = flag_outlier_epochs(
        data_v=data_v,
        ch_names=ch_names,
        ptp_thresh_uv=settings["ptp_thresh_uv"],
        maxabs_thresh_uv=settings["maxabs_thresh_uv"],
        rms_robust_z_thresh=settings["rms_rz_thresh"],
        use_ch_rms_robust_z=settings["use_ch_rms_rz"],
        ch_rms_robust_z_thresh=settings["ch_rz_thresh"],
    )

    flagged = np.where(flag_mask)[0].tolist()
    print(f"  Flagged: {len(flagged)} / {len(epochs)}")

    # If nothing flagged, just save copy and decision file
    decisions = {}
    if flagged:
        # Blocks until the user finishes (or stops) the review.
        dlg = EpochReviewDialog(parent, epochs, flagged, reasons, metrics, subj)
        decisions = dlg.decisions

    # Any flagged not reviewed explicitly -> default KEEP
    for idx in flagged:
        decisions.setdefault(idx, "keep")

    drop_idx = sorted(idx for idx, d in decisions.items() if d == "drop")

    # Save cleaned epochs
    epochs_clean = epochs.copy()
    if drop_idx:
        epochs_clean.drop(drop_idx, reason="manual_outlier")

    epochs_clean.save(out_fif, overwrite=True, verbose="ERROR")

    # Decision report: one row per flagged epoch. epoch_index is the position
    # of the epoch within the loaded input file.
    columns = [
        "participant",
        "epoch_index",
        "decision",
        "reasons",
        "ptp_max_uV",
        "maxabs_max_uV",
        "epoch_rms_uV",
        "epoch_rms_robustZ",
    ]
    rows = [
        {
            "participant": subj,
            "epoch_index": idx,
            "decision": decisions.get(idx, "keep"),
            "reasons": "; ".join(reasons[idx]) if reasons[idx] else "",
            "ptp_max_uV": float(metrics["ptp_max_uv"][idx]),
            "maxabs_max_uV": float(metrics["maxabs_max_uv"][idx]),
            "epoch_rms_uV": float(metrics["epoch_rms_uv"][idx]),
            "epoch_rms_robustZ": float(metrics["epoch_rms_robustz"][idx]),
        }
        for idx in flagged
    ]
    # Explicit columns keep the header row even when nothing was flagged;
    # otherwise the CSV would be empty and unreadable by pd.read_csv.
    pd.DataFrame(rows, columns=columns).to_csv(out_dec, index=False)

    return {
        "participant": subj,
        "n_epochs_orig": int(len(epochs)),
        "n_flagged": int(len(flagged)),
        "n_dropped_manual": int(len(drop_idx)),
        "out_fif": out_fif,
        "decisions_csv": out_dec
    }


# ============================================================
# GUI APP
# ============================================================
class App(tk.Tk):
    """
    Main window of the batch manual-review tool.

    Lets the user pick input/output folders and flagging thresholds, then
    processes every matching epochs file in the input folder in sorted order,
    opening a review dialog per participant, and finally writes a batch
    summary CSV into the output folder.
    """
    def __init__(self):
        super().__init__()
        self.title("Manual Epoch Outlier Review (Batch)")
        self.resizable(False, False)

        self.in_dir = tk.StringVar(value="")
        self.out_dir = tk.StringVar(value="")

        # Thresholds are kept as strings so a blank field can mean "disabled".
        self.ptp_thresh = tk.StringVar(value=str(DEFAULT_PTP_THRESH_UV))
        self.maxabs_thresh = tk.StringVar(value=str(DEFAULT_MAXABS_THRESH_UV))
        self.rms_rz_thresh = tk.StringVar(value=str(DEFAULT_RMS_ROBUST_Z_THRESH))

        self.use_ch_rz = tk.BooleanVar(value=DEFAULT_USE_CH_RMS_ROBUST_Z)
        self.ch_rz_thresh = tk.StringVar(value=str(DEFAULT_CH_RMS_ROBUST_Z_THRESH))

        self.status = tk.StringVar(value="Ready.")
        self._build()

    def _build(self):
        """Create and lay out all widgets of the main window."""
        frm = ttk.Frame(self, padding=10)
        frm.pack(fill="both", expand=True)

        # --- folder selection ---
        ttk.Label(frm, text="Input folder (contains *_cleaned-epo.fif files)").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.in_dir, width=60).grid(row=1, column=0, padx=8, pady=6, sticky="w")
        ttk.Button(frm, text="Browse…", command=self.pick_in).grid(row=1, column=1, sticky="w")

        ttk.Label(frm, text="Output folder").grid(row=2, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.out_dir, width=60).grid(row=3, column=0, padx=8, pady=6, sticky="w")
        ttk.Button(frm, text="Browse…", command=self.pick_out).grid(row=3, column=1, sticky="w")

        ttk.Separator(frm).grid(row=4, column=0, columnspan=2, sticky="ew", pady=(10, 10))

        # --- flagging thresholds ---
        ttk.Label(frm, text="Flagging Rules (these ONLY flag; you decide keep/drop)").grid(row=5, column=0, sticky="w")

        grid = ttk.Frame(frm)
        grid.grid(row=6, column=0, columnspan=2, sticky="w", padx=2, pady=6)

        ttk.Label(grid, text="Peak-to-peak threshold (µV) [blank disables]:").grid(row=0, column=0, sticky="w")
        ttk.Entry(grid, textvariable=self.ptp_thresh, width=10).grid(row=0, column=1, sticky="w", padx=(6, 18))

        ttk.Label(grid, text="Max |amplitude| threshold (µV) [blank disables]:").grid(row=0, column=2, sticky="w")
        ttk.Entry(grid, textvariable=self.maxabs_thresh, width=10).grid(row=0, column=3, sticky="w", padx=(6, 0))

        ttk.Label(grid, text="Epoch RMS robust-z threshold:").grid(row=1, column=0, sticky="w", pady=(10, 0))
        ttk.Entry(grid, textvariable=self.rms_rz_thresh, width=10).grid(row=1, column=1, sticky="w", padx=(6, 18), pady=(10, 0))

        ttk.Checkbutton(grid, text="Also flag if any channel RMS robust-z exceeds threshold", variable=self.use_ch_rz)\
            .grid(row=1, column=2, sticky="w", pady=(10, 0))

        ttk.Label(grid, text="Channel RMS robust-z threshold:").grid(row=2, column=2, sticky="w", pady=(10, 0))
        ttk.Entry(grid, textvariable=self.ch_rz_thresh, width=10).grid(row=2, column=3, sticky="w", padx=(6, 0), pady=(10, 0))

        ttk.Separator(frm).grid(row=7, column=0, columnspan=2, sticky="ew", pady=(10, 10))

        # --- run + status line ---
        ttk.Button(frm, text="Run manual review (batch)", command=self.run).grid(row=8, column=0, sticky="w")
        ttk.Label(frm, textvariable=self.status, wraplength=560).grid(row=9, column=0, columnspan=2, sticky="w", pady=(8, 0))

    def pick_in(self):
        """Ask for the input folder; also suggest an output folder if none is set."""
        d = filedialog.askdirectory(title="Select input folder")
        if d:
            self.in_dir.set(d)
            if not self.out_dir.get().strip():
                self.out_dir.set(os.path.join(d, "manual_outliers_removed"))

    def pick_out(self):
        """Ask for the output folder."""
        d = filedialog.askdirectory(title="Select output folder")
        if d:
            self.out_dir.set(d)

    @staticmethod
    def _parse_float_or_none(s: str):
        """Return ``float(s)``, or None when ``s`` is blank. Raises ValueError on bad text."""
        s = s.strip()
        if s == "":
            return None
        return float(s)

    def _read_settings(self) -> dict:
        """
        Parse the threshold fields into the settings dict used for flagging.

        The channel threshold is only parsed when the per-channel rule is
        enabled, so a blank/invalid value in that field does not block a run
        in which the rule is switched off (it is passed as None then).

        Raises ValueError if a required field cannot be parsed as a number.
        """
        use_ch = bool(self.use_ch_rz.get())
        return {
            "ptp_thresh_uv": self._parse_float_or_none(self.ptp_thresh.get()),
            "maxabs_thresh_uv": self._parse_float_or_none(self.maxabs_thresh.get()),
            "rms_rz_thresh": float(self.rms_rz_thresh.get().strip()),
            "use_ch_rms_rz": use_ch,
            "ch_rz_thresh": float(self.ch_rz_thresh.get().strip()) if use_ch else None,
        }

    def run(self):
        """Validate the inputs, review every participant file, then write the batch summary."""
        in_dir = self.in_dir.get().strip()
        out_dir = self.out_dir.get().strip()

        if not in_dir or not os.path.isdir(in_dir):
            messagebox.showerror("Error", "Please select a valid input folder.")
            return
        if not out_dir:
            messagebox.showerror("Error", "Please select a valid output folder.")
            return

        fif_files = sorted(glob.glob(os.path.join(in_dir, EPOCH_FIF_GLOB)))
        if not fif_files:
            messagebox.showerror("No files found", f"No files matching {EPOCH_FIF_GLOB} in:\n{in_dir}")
            return

        try:
            settings = self._read_settings()
        except ValueError as e:
            messagebox.showerror("Invalid settings", f"Could not parse thresholds:\n\n{e}")
            return

        summary = []
        for i, fpath in enumerate(fif_files, start=1):
            self.status.set(f"Participant {i}/{len(fif_files)}: {os.path.basename(fpath)} (reviewing flagged epochs...)")
            self.update_idletasks()

            # Top-level boundary: report any failure to the user and stop the batch.
            try:
                res = process_one_participant_manual(self, fpath, out_dir, settings)
                summary.append(res)
            except Exception as e:
                messagebox.showerror("Error", f"Failed on:\n{fpath}\n\n{e}")
                self.status.set("Stopped due to error.")
                return

        os.makedirs(out_dir, exist_ok=True)
        batch_csv = os.path.join(out_dir, "batch_manual_outlier_summary.csv")
        pd.DataFrame(summary).to_csv(batch_csv, index=False)

        self.status.set(f"Done. Saved:\n{batch_csv}")
        messagebox.showinfo("Finished", f"Batch complete.\n\nSaved summary:\n{batch_csv}")


if __name__ == "__main__":
    # Build the main window and hand control to Tk's event loop.
    App().mainloop()
